{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import classification_report\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1797, 64)\n",
      "(1797,)\n"
     ]
    }
   ],
   "source": [
    "# Load the digits dataset and separate the feature matrix from the label vector\n",
    "digits = load_digits()\n",
    "x_data, y_data = digits.data, digits.target\n",
    "# Print the shape of the features, then of the labels, one per line\n",
    "print(x_data.shape, y_data.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAKr0lEQVR4nO3d34tc9RnH8c+nq9Kq0YXUFs2GroIEpNCNhIAEhMa2xComF71IQCFSyJWiNCDaq+QfkPSiCEvUDZgqbVQQsVpBFyu01iRuq3FjSUJKttFGKcFooSH69GInEO3a/c7M+bVP3y9Y3Jkdcp4heXvOnJ05X0eEAOTxtbYHAFAtogaSIWogGaIGkiFqIJmL6vhDbac8pT4yMtLo9q655prGtrV8+fLGtnXmzJnGtnX06NHGttW0iPBC99cSdVbLli1rdHvbt29vbFtbt25tbFvT09ONbWvTpk2NbasrOPwGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIpitr2Btvv2T5i+8G6hwIwuEWjtj0i6ZeSbpV0g6Qttm+oezAAgynZU6+VdCQijkXEWUlPSdpY71gABlUS9QpJJy64Pde77wtsb7O93/b+qoYD0L+ST2kt9PGu//poZURMSpqU8n70ElgKSvbUc5JWXnB7TNLJesYBMKySqN+UdL3ta21fImmzpOfqHQvAoBY9/I6Ic7bvkfSSpBFJj0XEodonAzCQoiufRMQLkl6oeRYAFeAdZUAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyrNDRh6mpqUa3t3Fjcx+G27lzZ2PbanI1kCa3JTX/b2Qh7KmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkimZIWOx2yfsv1OEwMBGE7JnnpK0oaa5wBQkUWjjojXJP2zgVkAVKCyT2nZ3iZpW1V/HoDBVBY1y+4A3cDZbyAZogaSKfmV1pOS/iBple052z+tfywAgypZS2tLE4MAqAaH30AyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyS37ZnfHx8ca21eQyOJK0Z8+exra1Y8eOxrY1Ojra2LYmJiYa21ZXsKcGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiCZkmuUrbT9qu1Z24ds39fEYAAGU/Le73OStkfEQdvLJB2w/XJEvFvzbAAGULLszvsRcbD3/RlJs5JW1D0YgMH09Skt2+OSVkt6Y4GfsewO0AHFUdu+XNLTku6PiI+//HOW3QG6oejst+2LNR/03oh4pt6RAAyj5Oy3JT0qaTYiHq5/JADDKNlTr5N0l6T1tmd6Xz+ueS4AAypZdud1SW5gFgAV4B1lQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSSz5NfSOn36dNsj1GZqaqrtEWqR+e+sC9hTA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJlFx48Ou2/2T7z71ld3Y2MRiAwZS8TfTfktZHxCe9SwW/bvu3EfHHmmcDMICSCw+GpE96Ny/ufXGxfqCjSi/mP2J7RtIpSS9HxILL7tjeb3t/1UMCKFcUdUR8FhETksYkrbX93QUeMxkRayJiTdVDAijX19nviDgtaVrShlqmATC0krPfV9ke7X3/DUk/kHS47sEADKbk7PfVkvbYHtH8/wR+HRHP1zsWgEGVnP3+i+bXpAawBPCOMiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSWfLL7kxMTLQ9AtAp7KmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkimOOreBf3fss1FB4EO62dPfZ+k2boGAVCN0mV3xiTdJml3veMAGFbpnnqXpAckff5VD2AtLaAbSlbouF3SqYg48L8ex1paQDeU7KnXSbrD9nFJT0lab/uJWqcCMLBFo46IhyJiLCLGJW2W9EpE3Fn7ZAAG
wu+pgWT6upxRRExrfilbAB3FnhpIhqiBZIgaSIaogWSIGkiGqIFkiBpIZskvuzMzM9P2CLW58sorG9vW6OhoY9tqcqmkHTt2NLatrmBPDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkVvE+1dSfSMpM8kneMywEB39fPe7+9HxEe1TQKgEhx+A8mURh2Sfmf7gO1tCz2AZXeAbig9/F4XESdtf0vSy7YPR8RrFz4gIiYlTUqS7ah4TgCFivbUEXGy999Tkp6VtLbOoQAMrmSBvMtsLzv/vaQfSXqn7sEADKbk8Pvbkp61ff7xv4qIF2udCsDAFo06Io5J+l4DswCoAL/SApIhaiAZogaSIWogGaIGkiFqIBmiBpJxRPVv08763u/p6em2R6jN8ePH2x6hFlu3bm17hNpEhBe6nz01kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJFEVte9T2PtuHbc/avqnuwQAMpvS637+Q9GJE/MT2JZIurXEmAENYNGrbV0i6WdJWSYqIs5LO1jsWgEGVHH5fJ+lDSY/bfsv27t71v7+AZXeAbiiJ+iJJN0p6JCJWS/pU0oNfflBETEbEGpa5BdpVEvWcpLmIeKN3e5/mIwfQQYtGHREfSDphe1XvrlskvVvrVAAGVnr2+15Je3tnvo9Juru+kQAMoyjqiJiRxGtlYAngHWVAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJMNaWn0YHR1tdHu7du1qbFsTExONbavJ9a1mZmYa21bTWEsL+D9B1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0ks2jUtlfZnrng62Pb9zcxHID+LXqNsoh4T9KEJNkekfR3Sc/WPBeAAfV7+H2LpKMR8bc6hgEwvNJLBJ+3WdKTC/3A9jZJ24aeCMBQivfUvWt+3yHpNwv9nGV3gG7o5/D7VkkHI+IfdQ0DYHj9RL1FX3HoDaA7iqK2famkH0p6pt5xAAyrdNmdf0laXvMsACrAO8qAZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKauZXc+lNTvxzO/KemjyofphqzPjefVnu9ExFUL/aCWqAdhe3/WT3hlfW48r27i8BtIhqiBZLoU9WTbA9Qo63PjeXVQZ15TA6hGl/bUACpA1EAynYja9gbb79k+YvvBtuepgu2Vtl+1PWv7kO372p6pSrZHbL9l+/m2Z6mS7VHb+2wf7v3d3dT2TP1q/TV1b4GAv2r+cklzkt6UtCUi3m11sCHZvlrS1RFx0PYySQckbVrqz+s82z+TtEbSFRFxe9vzVMX2Hkm/j4jdvSvoXhoRp9ueqx9d2FOvlXQkIo5FxFlJT0na2PJMQ4uI9yPiYO/7M5JmJa1od6pq2B6TdJuk3W3PUiXbV0i6WdKjkhQRZ5da0FI3ol4h6cQFt+eU5B//ebbHJa2W9Ea7k1Rml6QHJH3e9iAVu07Sh5Ie77202G37sraH6lcXovYC96X5PZvtyyU9Len+iPi47XmGZft2Saci4kDbs9TgIkk3SnokIlZL+lTSkjvH04Wo5yStvOD2mKSTLc1SKdsXaz7ovRGR5fLK6yTdYfu45l8qrbf9RLsjVWZO0lxEnD+i2qf5yJeULkT9pqTrbV/bOzGxWdJzLc80NNvW/Guz2Yh4uO15qhIRD0XEWESMa/7v6pWIuLPlsSoRER9IOmF7Ve+uWyQtuROb/S6QV7mIOGf7HkkvSRqR9FhEHGp5rCqsk3SXpLdtz/Tu+3lEvNDiTFjcvZL29nYwxyTd3fI8fWv9V1oAqtWFw28AFSJqIBmiBpIhaiAZogaSIWogGaIGkvkPFqN95o/i4FcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Display one sample image (index 10) in grayscale using the explicit Axes interface\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(digits.images[10], cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the data into training and test sets.\n",
    "# A fixed random_state makes the split identical on every run, so results are reproducible.\n",
    "x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=500)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the model: two hidden layers (100 neurons, then 50), trained for at most 500 iterations.\n",
    "# A fixed random_state seeds weight initialisation and batch sampling so training is repeatable.\n",
    "mlp = MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=500, random_state=0)\n",
    "mlp.fit(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00        52\n",
      "           1       0.90      1.00      0.95        47\n",
      "           2       1.00      1.00      1.00        43\n",
      "           3       1.00      0.98      0.99        44\n",
      "           4       1.00      0.97      0.99        38\n",
      "           5       0.96      1.00      0.98        48\n",
      "           6       0.98      0.96      0.97        53\n",
      "           7       0.98      1.00      0.99        47\n",
      "           8       0.95      0.88      0.91        40\n",
      "           9       1.00      0.95      0.97        38\n",
      "\n",
      "    accuracy                           0.98       450\n",
      "   macro avg       0.98      0.97      0.97       450\n",
      "weighted avg       0.98      0.98      0.98       450\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predict labels for the held-out test set, then print per-class precision, recall and F1\n",
    "predictions = mlp.predict(x_test)\n",
    "print(classification_report(y_true=y_test, y_pred=predictions))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
